"""
Patent-to-Code Embedding Demo
=============================
This script shows how we embed patent text and software code with
DeepSeek-Coder-1.3B (mean-pooled last-layer hidden states) and compute
their semantic similarity as the cosine similarity of those embeddings.

The embeddings capture the technical meaning of both patents and code,
enabling us to measure how closely a code contribution relates to a patent.

Author: Sergio Petralia
Requirements: numpy, torch, transformers

Note: First run downloads the model (~2.5GB) from HuggingFace Hub.

To run (CPU is sufficient for this demo):
    python embedding_demo.py
"""

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# =============================================================================
# Sample Data
# =============================================================================

# Sample patent abstract about batch processing for ML inference.
# Like the two code samples below, the literal opens and closes on its own
# line, so the string starts and ends with a newline; those newlines are
# tokenized together with the text by get_embedding.
SAMPLE_PATENT = """
A method and system for batching inputs to a machine learning model. 
The system receives a stream of inference requests, each containing input 
data for processing by the model. The requests are added to a queue and 
when the queue size reaches a threshold batch size, the inputs are combined 
into a batched tensor and processed together by the neural network for 
improved computational efficiency.
"""

# Sample Python code implementing batch inference - clearly related to the
# patent: it queues requests and stacks them into one tensor once
# max_batch_size is reached. It is only embedded as text, never executed,
# and is expected to score higher similarity to SAMPLE_PATENT.
RELATED_CODE = """
class BatchInferenceServer:
    def __init__(self, model, max_batch_size=32):
        self.model = model
        self.max_batch_size = max_batch_size
        self.request_queue = []
    
    def add_request(self, input_tensor):
        self.request_queue.append(input_tensor)
        if len(self.request_queue) >= self.max_batch_size:
            return self.process_batch()
        return None
    
    def process_batch(self):
        if not self.request_queue:
            return []
        # Combine inputs into batched tensor
        batched_input = torch.stack(self.request_queue)
        self.request_queue = []
        # Run batch inference
        with torch.no_grad():
            outputs = self.model(batched_input)
        return outputs
"""

# Control sample: SQL payroll queries with no ML or batching content,
# expected to score lower similarity to SAMPLE_PATENT than RELATED_CODE.
UNRELATED_CODE = """
-- Employee payroll queries
SELECT employee_id, first_name, last_name, salary
FROM employees
WHERE department = 'Sales'
ORDER BY salary DESC;

UPDATE employees 
SET salary = salary * 1.05
WHERE hire_date < '2020-01-01';

INSERT INTO payroll_history (employee_id, pay_date, amount)
SELECT employee_id, CURRENT_DATE, salary 
FROM employees
WHERE status = 'active';
"""


# =============================================================================
# Model Loading
# =============================================================================

def load_model(device="cpu"):
    """
    Load the DeepSeek-Coder-1.3B base checkpoint and its tokenizer.

    The first call downloads the weights from the HuggingFace Hub.

    Args:
        device: Torch device string the model is moved to (e.g. "cpu").

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.
    """
    checkpoint = "deepseek-ai/deepseek-coder-1.3b-base"

    print("Loading DeepSeek-Coder-1.3B model...")
    print(f"Device: {device}")

    # NOTE(review): trust_remote_code=True allows executing code shipped in
    # the Hub repo; confirm this checkpoint actually needs it.
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    # Padding requires a pad token; reuse EOS when the tokenizer lacks one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Weights are kept in float32 for CPU compatibility.
    weights_dtype = torch.float32
    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        torch_dtype=weights_dtype,
    ).to(device)
    lm.eval()  # inference only

    print("✓ Model loaded successfully\n")
    return lm, tok


# =============================================================================
# Embedding Function
# =============================================================================

def get_embedding(text, model, tokenizer, device="cpu", max_length=512):
    """
    Generate an embedding vector for the given text using mean pooling.

    The embedding is the average of the model's last-layer hidden states
    over all non-padding tokens.

    Args:
        text: String to embed. A list of strings is passed straight to the
            tokenizer and yields one row per string.
        model: Model called with ``output_hidden_states=True``; its output
            must expose a ``hidden_states`` sequence.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized inputs are moved to (should match the
            model's device).
        max_length: Inputs longer than this many tokens are truncated.

    Returns:
        A NumPy array of shape ``(hidden_dim,)`` for a single string
        (``(batch, hidden_dim)`` for several), with at least float32
        precision regardless of the model's dtype.
    """
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        # Pool in at least float32. Summing half-precision activations loses
        # precision, and NumPy has no bfloat16, so .numpy() would raise on a
        # bfloat16 model's output.
        hidden_states = hidden_states.to(
            torch.promote_types(hidden_states.dtype, torch.float32)
        )

    # Mean pooling: zero out padded positions, then divide by the number of
    # real tokens (clamped to 1 so an all-padding row cannot divide by 0).
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    # Drop only the batch axis, so a single string gives a 1-D vector.
    embedding = (summed / counts).squeeze(0)

    return embedding.cpu().numpy()


# =============================================================================
# Similarity Calculation
# =============================================================================

def cosine_similarity(vec1, vec2):
    """
    Calculate cosine similarity between two vectors.

    Args:
        vec1: First vector (NumPy array or any array-like of numbers).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        A Python ``float`` in [-1.0, 1.0]; 0.0 when either vector has zero
        norm (cosine similarity is undefined there).
    """
    # Compute in float64 so float32 embeddings do not lose precision in the
    # dot product and norms; this also accepts plain lists of ints.
    a = np.asarray(vec1, dtype=np.float64)
    b = np.asarray(vec2, dtype=np.float64)

    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0

    similarity = np.dot(a, b) / (norm_a * norm_b)
    # Rounding can push the ratio marginally outside [-1, 1]; np.clip keeps
    # NaN as NaN instead of silently mapping it to a bound.
    return float(np.clip(similarity, -1.0, 1.0))


# =============================================================================
# Main Demo
# =============================================================================

def main():
    """
    Run the demo end to end: load the model on CPU, embed the sample patent
    and both code samples, and print the patent-to-code cosine similarities.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 40

    print(heavy_rule)
    print("Patent-to-Code Embedding Demo")
    print(heavy_rule)
    print()

    # CPU only, so the demo runs on machines without a GPU.
    device = "cpu"
    model, tokenizer = load_model(device)

    print("Generating embeddings...")
    print(light_rule)

    # (progress message, text to embed) in the order they are reported.
    samples = (
        ("Embedding patent abstract...", SAMPLE_PATENT),
        ("Embedding related code...", RELATED_CODE),
        ("Embedding unrelated code...", UNRELATED_CODE),
    )
    vectors = []
    for message, text in samples:
        print(message, end=" ")
        vector = get_embedding(text, model, tokenizer, device)
        print(f"Done (dim={len(vector)})")
        vectors.append(vector)
    patent_vec, related_vec, unrelated_vec = vectors

    print()
    for line in (light_rule, "Calculating similarities...", light_rule):
        print(line)

    sim_related = cosine_similarity(patent_vec, related_vec)
    sim_unrelated = cosine_similarity(patent_vec, unrelated_vec)

    print()
    print("Results:")
    print(f"  Patent ↔ Related code:   {sim_related:.4f}")
    print(f"  Patent ↔ Unrelated code: {sim_unrelated:.4f}")
    print()

    for line in (heavy_rule, "Demo complete!", heavy_rule):
        print(line)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
